
# Start from an empty global environment so results do not depend on objects
# left over from an earlier interactive session.
rm(list = ls())

# Attach the packages the script relies on (same packages, same order).
# NOTE(review): ctgt presumably supplies criticalvalue() and getL() used
# below — confirm.
invisible(lapply(c("ctgt", "stringr", "xlsx"), library, character.only = TRUE))


# Simulated design ----
#
# n rows, each an AR(1) series (ar = 0.2) of length m. Every row is drawn
# under its own seed, 1234 + row index, so each row is reproducible on its
# own. X is an n x m double matrix.
n <- 100
m <- 5
X <- t(vapply(seq_len(n), function(row) {
  set.seed(1234 + row)
  as.vector(arima.sim(model = list(order = c(1, 0, 0), ar = 0.2), n = m))
}, numeric(m)))

# Binary outcome with success probability 0.6, drawn right after the last
# row's simulation. Rows with y == 1 get +0.8 added to the first three
# covariates.
y <- rbinom(n, 1, 0.6)
X[y == 1, 1:3] <- X[y == 1, 1:3] + 0.8


#cor(X[,1:50],X[,1:50])
# Covariate names "x1", ..., "xm". Pasting an integer sequence avoids the
# scientific-notation names (e.g. "x1e+05") that pasting a double sequence
# yields for large m, and needs no whitespace post-processing.
xs <- paste0("x", seq_len(m))
colnames(X) <- xs

# Score statistics ----
#
# sqrW is the square root of the sample Bernoulli variance
# mean(y) * (1 - mean(y)) and acts as the scalar W^{1/2}. It is zero when y
# is constant (and NA if y has missing values). Either case would make every
# statistic below degenerate, so fail early.
sqrW <- sqrt(mean(y) * (1 - mean(y)))
if (!isTRUE(sqrW > 0)) {
  stop("y must contain both 0s and 1s (its sample variance is zero or NA)",
       call. = FALSE)
}

# Compute the centred matrix directly rather than dividing WIHZ by sqrW
# (which is 0/0 when sqrW is zero).
IHZ <- sweep(X, 2, colMeans(X))  # (I - H) Z: column-centred covariates
WIHZ <- sqrW * IHZ               # W^{1/2} (I - H) Z

# Score statistic y^T (I - H) Z_S Z_S^T (I - H) y for covariate columns `cols`
# (names or indices).
score_stat <- function(cols) {
  sum(colSums(y * IHZ[, cols, drop = FALSE])^2)
}

# Eigenvalues, rounded to 8 decimal places, of
# W^{1/2} (I - H) Z_S Z_S^T (I - H) W^{1/2} for covariate columns `cols`.
weighted_eigen <- function(cols) {
  ev <- eigen(tcrossprod(WIHZ[, cols, drop = FALSE]),
              symmetric = TRUE, only.values = TRUE)$values
  round(ev, 8)
}

# Full model: all covariates.
Tf <- score_stat(xs)
Lamf <- weighted_eigen(xs)
Cf <- criticalvalue(Lamf)

# Per-covariate score statistic (tt) and eigenvalue total (ww). Their ratio q
# is named by covariate and sorted ascending.
tt <- vapply(seq_len(m), score_stat, numeric(1))
ww <- vapply(seq_len(m), function(j) sum(weighted_eigen(j)), numeric(1))
q <- sort(setNames(tt / ww, xs))


# Crossing paths for the reduced hypotheses ----
#
# trace_crossings() runs one iteration shared by every reduced hypothesis.
# Starting from the full-model critical value crit_full and level
# sum(lam_full), it repeats:
#   1. k <- index of the first tabulated tmin at or above the current
#      critical value;
#   2. lf <- level at which a line through (level[k - 1], tmin[k - 1]) with
#      slope qs[k - 1] reaches the current critical value;
#   3. record (lf, current critical value);
#   4. new critical value <- criticalvalue(getL(lam_full, lam_red, lf)).
# It stops when k reaches 1, when |lf - lr| <= tol, or when the critical
# value drops below the first tabulated tmin. In the last case the path is
# closed at sum(lam_red).
#
# Arguments:
#   tc        object with `tmin` and `level` components
#             (NOTE(review): loaded from gmincmax*.Rdata — confirm structure)
#   R         name of the covariate whose eigenvalues form lam_red; it is
#             also excluded from the slopes qs
#   rows      rows of tc$tmin / tc$level used as the tabulated grid
#   tol       convergence tolerance on |lf - lr|
#   wihz, q_sorted, lam_full, crit_full
#             default to the globals WIHZ, q, Lamf and Cf at call time
# Prints the starting index k and the slopes qs. Returns cbind(lres, cres):
# recorded levels and critical values.
trace_crossings <- function(tc, R, rows, tol = 1e-4,
                            wihz = WIHZ, q_sorted = q,
                            lam_full = Lamf, crit_full = Cf) {
  lam_red <- round(eigen(tcrossprod(wihz[, R, drop = FALSE]),
                         symmetric = TRUE, only.values = TRUE)$values, 8)
  tmin <- tc$tmin[rows]
  level <- tc$level[rows]
  qs <- q_sorted[setdiff(names(q_sorted), R)]

  # Index of the first tabulated tmin >= crit. Fail loudly here rather than
  # letting min(which(...)) return Inf (with a warning) and push NA indices
  # into the loop, where they surface as an unrelated "missing value" error.
  first_at_or_above <- function(crit) {
    k <- which(tmin >= crit)[1]
    if (is.na(k)) {
      stop("critical value ", format(crit),
           " exceeds every tabulated tmin for ", R, call. = FALSE)
    }
    k
  }

  crit <- crit_full
  lf <- sum(lam_full)
  lr <- sum(lam_red)
  k <- first_at_or_above(crit)
  print(k)
  print(qs)

  lres <- lf
  cres <- crit
  while (k > 1 && abs(lf - lr) > tol) {
    lf <- (crit - tmin[k - 1]) / qs[k - 1] + level[k - 1]
    lres <- c(lres, lf)
    cres <- c(cres, crit)
    crit <- criticalvalue(getL(lam_full, lam_red, lf))
    if (tmin[1] > crit) {
      # Below the whole grid: close the path at the reduced-model total.
      cres <- c(cres, crit)
      lres <- c(lres, sum(lam_red))
      break
    }
    k <- first_at_or_above(crit)
    lr <- (crit - tmin[k - 1]) / qs[k - 1] + level[k - 1]
  }
  cbind(lres, cres)
}

#H3#
load("gmincmax3.Rdata")
# NOTE(review): H3 uses tabulated row 15 where H2 and H1 use row 14 —
# confirm this is intentional.
cross3 <- trace_crossings(tc3, "x3", rows = c(1, 4, 11, 15, 16))
cross3
save(cross3, file = "cross3.Rdata")

#H2#
load("gmincmax2.Rdata")
cross2 <- trace_crossings(tc2, "x2", rows = c(1, 4, 11, 14, 16))
cross2
save(cross2, file = "cross2.Rdata")

#H1#
load("gmincmax1.Rdata")
cross1 <- trace_crossings(tc1, "x1", rows = c(1, 4, 11, 14, 16))
cross1
save(cross1, file = "cross1.Rdata")
